import matplotlib.pyplot as plt

# === 🔢 Données initiales (à modifier par l'élève) ===
n1 = 2.5  # Quantité initiale de Cu²⁺ en mmol
n2 = 6.0  # Quantité initiale de OH⁻ en mmol
x = 2.5   # 🔁 Avancement choisi par l’élève (en mmol)
# ================================================

# Avancement maximal (réaction totale)
x_max = min(n1, n2 / 2)

# Vérification de la validité de x
if x > x_max:
    print("❌ Erreur : l'avancement dépasse ce qui est possible.")
    print("On ne peut pas consommer plus que le réactif limitant disponible!")
else:
    # Fonction pour calculer les quantités à un avancement donné
    def quantites(n1, n2, x):
        return {
            'Cu²⁺': max(n1 - x, 0),
            'OH⁻': max(n2 - 2 * x, 0),
            'Cu(OH)₂': x,
        }

    # Couleurs des espèces chimiques
    couleurs = {
        'Cu²⁺': 'cyan',
        'OH⁻': 'green',
        'Cu(OH)₂': 'pink',
    }

    # Fonction d'affichage
    def plot_state(n1, n2, x):
        etat = quantites(n1, n2, x)
        
        species = ['Cu²⁺', 'OH⁻', 'Cu(OH)₂']
        positions = range(len(species))
        quantites_initiales = [n1, n2, 0]

        plt.figure(figsize=(8, 5))

        # Barres initiales transparentes
        plt.bar(
            positions,
            quantites_initiales,
            width=0.5,
            color='lightgray',
            alpha=0.4,
            label='Quantité initiale'
        )

        # Barres finales
        y = [etat[s] for s in species]
        bars = plt.bar(
            positions,
            y,
            width=0.5,
            color=[couleurs[s] for s in species],
            label='Quantité finale'
        )

        # Ajout des flèches et textes
        for bar, s, n_init in zip(bars, species, quantites_initiales):
            n_final = bar.get_height()
            delta = n_final - n_init
            signe = '+' if delta > 0 else '−'

            x_center = bar.get_x() + bar.get_width() / 2
            y_start = max(n_init, n_final)
            y_end = min(n_init, n_final)
            y_middle = (n_init + n_final) / 2

            if n_init != n_final:
                # Double flèche
                plt.annotate(
                    '',
                    xy=(x_center, y_end),
                    xytext=(x_center, y_start),
                    arrowprops=dict(arrowstyle='<->', color='black', lw=1)
                )

                # Δn à droite de la flèche
                plt.text(
                    x_center + 0.05,
                    y_middle,
                    f"Δn = {signe}{abs(delta):.2f} mmol",
                    ha='left',
                    va='center',
                    fontsize=10,
                    color='black'
                )

        plt.xticks(positions, species)
        plt.ylabel("Quantité de matière (mmol)")
        plt.title(f"Avancement choisi : x = {x:.2f} mmol")
        plt.grid(axis='y', linestyle='--', alpha=0.5)
        plt.legend()
        plt.tight_layout()
        plt.show()
        plt.close()

    # Exécution du tracé
    plot_state(n1, n2, x)
